-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Add Distillation API to Keras #21572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Distillation API to Keras #21572
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @divyashreepathihalli, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new Knowledge Distillation API to Keras, designed to facilitate the efficient transfer of learned knowledge from larger, pre-trained "teacher" models to smaller "student" models. The API seamlessly integrates with Keras's existing training, evaluation, and prediction workflows, providing a flexible and extensible framework for various distillation techniques.
Highlights
- New Distiller Model: A core Distiller class is added, which is a keras.Model subclass, enabling the combination and training of teacher and student models within the standard Keras workflow.
- Pluggable Distillation Strategies: Introduces a BaseDistillationStrategy and three concrete implementations: LogitsDistillation (for softening logits), FeatureDistillation (for intermediate feature matching), and MultiOutputDistillation (for handling models with multiple outputs).
- Configurable Loss Balancing: The Distiller allows specifying an alpha parameter to balance the contribution of the student's original loss and the distillation loss.
- Automatic Teacher Freezing: The teacher model is automatically set to non-trainable (trainable=False) during the distillation process to prevent its weights from being updated.
- Comprehensive Testing: New test files (distiller_test.py and strategies_test.py) are added to ensure the robustness and correctness of the new API, covering initialization, loss computation, and end-to-end workflows.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable strategies. My review has identified a few issues: a critical issue with the FeatureDistillation strategy which is not fully implemented, a high-severity issue with an unused temperature parameter in the Distiller class that could mislead users, and a medium-severity issue regarding a simplistic fallback for loss calculation in multi-output scenarios. Addressing these points will improve the robustness and clarity of this new API.
keras/src/distillation/distiller.py
Outdated
| if isinstance(y_pred, (list, tuple)): | ||
| # For multi-output models, use the first output for student | ||
| # loss | ||
| # This is a simplified approach for compatibility | ||
| if isinstance(y, (list, tuple)): | ||
| student_loss = self.student_loss_fn(y[0], y_pred[0]) | ||
| else: | ||
| student_loss = self.student_loss_fn(y, y_pred[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fallback logic for calculating the student loss in _compute_loss for multi-output models is overly simplistic as it always defaults to using the first output (y_pred[0]). This might not align with user expectations for all multi-output scenarios and could lead to incorrect training behavior if model.compile() is not called with a loss that properly handles multiple outputs.
While the primary path using self.compiled_loss is correct, this fallback could be made more robust. Consider raising a more specific error if a multi-output model is used without a compiled loss, or clarifying this behavior more explicitly in the documentation.
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a well-structured Knowledge Distillation API to Keras, a valuable addition for model compression and knowledge transfer. The implementation of the Distiller model and pluggable strategies is clean and the code is generally of high quality, with good documentation and examples.
I have identified a few significant issues that should be addressed to improve the robustness and completeness of the API:
- A critical flaw in
FeatureDistillationlimits its use to sequential models, which will prevent it from working with more complex architectures. - The
Distiller's loss computation for multi-output models has a bug in its fallback logic. - The
Distillermodel cannot be correctly serialized and deserialized due to an incompleteget_configimplementation. - The tests for
FeatureDistillationare not comprehensive enough to cover non-sequential model architectures.
Addressing these points will significantly enhance the reliability and usability of this new API. Overall, this is a great contribution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable DistillationStrategy classes. The code is generally clean and includes a good set of tests.
I've identified a few key areas for improvement:
- Serialization: The
DistillerandMultiOutputDistillationclasses are not correctly configured for serialization, which will prevent models using them from being saved and loaded. This is a critical feature in Keras that needs to be addressed. - Feature Extraction Robustness: The
FeatureDistillationstrategy uses a method for extracting intermediate layer features that is not robust and will fail for models with non-sequential architectures (e.g., ResNets). This limitation needs to be documented and ideally improved. - Code Simplification: There's a small piece of unreachable code in the
Distiller's loss computation that can be simplified.
Addressing these points will significantly improve the robustness and usability of this new API.
keras/src/distillation/distiller.py
Outdated
| if isinstance(y_pred, list) and len(y_pred) > 0: | ||
| # For multi-output, use first output for student loss | ||
| student_loss = self.student_loss_fn(y[0], y_pred[0]) | ||
| else: | ||
| student_loss = self.student_loss_fn(y, y_pred) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The isinstance(y_pred, list) check on line 283 is redundant because y_pred is converted to a list on line 263. This makes the else block on line 286 unreachable. The logic can be simplified to directly use the first output for the student loss calculation.
# Fallback: use student_loss_fn directly
# For multi-output, use first output for student loss
student_loss = self.student_loss_fn(y[0], y_pred[0])|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a comprehensive and well-designed Knowledge Distillation API to Keras. The implementation is robust, featuring a flexible Distiller class and a set of pluggable distillation strategies that cover common use cases like logits and feature distillation, as well as multi-output models. The code is accompanied by extensive and thorough tests, which is excellent. My feedback includes a couple of suggestions to improve code style in the API files and to enhance the robustness of a test case by removing a broad exception handler. Overall, this is a high-quality contribution that will be a valuable addition to Keras.
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #21572 +/- ##
==========================================
- Coverage 82.69% 82.65% -0.05%
==========================================
Files 573 577 +4
Lines 58888 59189 +301
Branches 9218 9277 +59
==========================================
+ Hits 48696 48921 +225
- Misses 7845 7891 +46
- Partials 2347 2377 +30
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! Some quick comments on the API.
keras/src/distillation/distiller.py
Outdated
| # Re-raise with context about which strategy failed | ||
| raise RuntimeError( | ||
| f"Failed to extract features for " | ||
| f"FeatureDistillation targeting teacher layer " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick {type(strategy)} instead of FeatureDistillation in case people subclass it.
| # Ensure student_loss is a scalar | ||
| if hasattr(student_loss, "shape") and len(student_loss.shape) > 0: | ||
| student_loss = keras.ops.mean(student_loss) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The strategy loss raises an error line 504 if it's not scalar. Any reason to handle these differently?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The student loss fn returns per-sample losses by default (shape, (batch_size,)) - the code converts to scalars with keras.ops.mean
But with stategy.compute_loss() method, it is supposed to return scalar losses.
| ) | ||
|
|
||
| # Verify that teacher and student outputs have the same structure | ||
| keras.tree.assert_same_structure(teacher_features, student_features) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is not using FeatureDistillation at all. You should instantiate a FeatureDistillation and call the validation method instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new Knowledge Distillation API to Keras, which is a valuable addition. The API design is clean and follows the end-to-end workflow principle from the Keras design guidelines. The core components like Distiller, LogitsDistillation, and FeatureDistillation are well-structured.
My review focuses on a few key areas:
- Correctness: I found a high-severity issue in the
FeatureDistillationloss calculation for cosine similarity which would lead to incorrect training behavior. - Maintainability: There are opportunities to make the code more robust, for example by using
isinstancefor type checking instead of string matching on names. - Style Guide Adherence: I've pointed out several minor violations of the Keras API design guidelines regarding docstring formatting, specifically the naming of
Args:andExamples:sections.
The implementation is solid overall, with good validation and efficient feature extraction for multi-strategy distillation. The accompanying tests are also quite comprehensive. Addressing the feedback will improve the correctness and maintainability of this new API.
keras/src/distillation/distiller.py
Outdated
| The teacher model is frozen during distillation. | ||
| student: A `keras.Model` to be trained through distillation. | ||
| strategies: List of distillation strategies to apply. Can be a single | ||
| strategy or a list of strategies like `LogitsDistillation`, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"strategy instances like keras.distillation.LogitsDistillation()..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is an inconsistency in terminology, where these are sometimes referred to as losses and sometimes as strategies. Which is it? (Could be both: training_strategy or something)
We should make the arg name here consistent with the object class name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these are DistillationLoss subclasses then the arg should be distillation_losses
A potential issue with calling them losses is that they're very different from keras.losses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done! renamed everything to distillation loss
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! Main request from me is to fix the naming inconsistency -- pick a standard class name for distillation losses / strategies and then always refer to instances of the class by that name, e.g. distillation_losses. No strong opinion on what the class name should be.
|
I have addressed all teh comments. Updated the name to be consistent - distillation loss everywhere instead of strategy. I will merge the PR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bulk replace did some weird stuff, in particular in the docstrings.
But also, the main question is whether the Distiller argument should be called distillation_losses (plural.).
| student_outputs: Outputs from the student model. Can be a single | ||
| tensor or a list/tuple of tensors for multi-output models. | ||
| **kwargs: Additional arguments for custom strategies. | ||
| **kwargs: Additional arguments for custom distillation_loss. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distillation losses.
| Raises: | ||
| ValueError: If models are not compatible with this strategy. | ||
| ValueError: If models are not compatible with this | ||
| distillation_loss. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distillation loss.
| @keras_export("keras.distillation.FeatureDistillation") | ||
| class FeatureDistillation(DistillationLoss): | ||
| """Feature distillation strategy using intermediate layer representations. | ||
| """Feature distillation distillation_loss. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feature distillation loss.
| This strategy applies temperature scaling to the teacher's logits before | ||
| computing the loss between teacher and student predictions. It's the most | ||
| common approach for knowledge distillation. | ||
| This distillation_loss applies temperature scaling to the teacher's logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This distillation loss...
| @pytest.mark.requires_trainable_backend | ||
| class TestLogitsDistillation(TestCase): | ||
| """Test cases for LogitsDistillation strategy.""" | ||
| """Test cases for LogitsDistillation distillation_loss.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove "distillation_loss" (actually, you can remove the whole line, it's not needed).
keras/src/distillation/distiller.py
Outdated
| for strategy in self.strategies: | ||
| for strategy in self.distillation_loss: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for distillation_loss in self.distillation_losses:
| Arguments: | ||
| model: The model to create an extractor for. | ||
| layer_names: List of layer names to extract features from. | ||
| Raises: | ||
| ValueError: If model has no symbolic inputs/outputs. | ||
| Returns: | ||
| Feature extractor model or `None` if no layer names | ||
| sprovided. | ||
| ` | ||
| Raises: | ||
| ValueError: If model has no symbolic inputs/outputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indentation seems off.
| for strategy, weight in zip( | ||
| self.distillation_loss, self.distillation_loss_weights | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for distillation_loss, weight in zip(self.distillation_losses...
| self.student | ||
| ), | ||
| "strategies": [ | ||
| "distillation_loss": [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distillation_losses?
| config["student"] | ||
| ) | ||
| config["strategies"] = [ | ||
| config["distillation_loss"] = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distillation_losses?
This PR adds Knowledge Distillation API to Keras,
Key Features
Core Components
Usage Examples
Basic Knowledge Distillation